//
//  crop.metal
//  EffectMgrMetal
//
//  Created by WS on 2021/5/26.
//  Copyright © 2021 WS. All rights reserved.
//

#include <metal_stdlib>
using namespace metal;

// CLAMP(v, min, max): clamps the lvalue `v` into [min, max] in place.
// Arguments are parenthesized so expression arguments (e.g. `a ? b : c`)
// bind as intended, and the expansion is wrapped in braces so an `else`
// following an invocation cannot attach to the macro's inner `if`.
// Usable with or without a trailing semicolon.
// NOTE: `v` and the bounds may be evaluated more than once; do not pass
// expressions with side effects.
#define CLAMP(v, min, max) \
    { \
        if ((v) < (min)) { \
            (v) = (min); \
        } else if ((v) > (max)) { \
            (v) = (max); \
        } \
    }

// Reads the texel at (x, y), clamping the coordinates into
// [0, inW - 1] x [0, inH - 1] so out-of-range lookups replicate the edge texel.
// Coordinates are signed so that negative indices (e.g. -1 from a bilinear
// footprint at the left/top edge) clamp to 0 instead of wrapping to a huge
// unsigned value and landing on the opposite edge.
static float4 GetPixelClamped(texture2d<float, access::read> in [[texture(0)]], int x, int y, int inW, int inH) {
    const int cx = clamp(x, 0, inW - 1);
    const int cy = clamp(y, 0, inH - 1);
    return in.read(uint2(uint(cx), uint(cy)));
}

// Linearly interpolates between A (t = 0) and B (t = 1), written as a
// weighted sum of the two endpoints.
static float Lerp (float A, float B, float t) {
    const float weightA = 1.0f - t;
    const float weightB = t;
    return weightA * A + weightB * B;
}

// Bilinearly samples `in` at normalized coordinates (u, v), with edge texels
// replicated for lookups that fall outside the texture.
//   in        - source texture (read access)
//   u, v      - normalized coordinates; texel i's centre is at (i + 0.5) / size
//   inW, inH  - source texture size in texels
// Returns the interpolated colour, each component clamped to [0, 255].
// NOTE(review): for normalized (unorm) formats read() already yields [0, 1],
// so the 255 upper bound looks like a leftover from 8-bit CPU code; kept as-is.
static float4 SampleBilinear (texture2d<float, access::read> in [[texture(0)]],
                       float u, float v, int inW, int inH) {
    // Convert to texel space and offset by half a texel so texel centres land
    // on integer positions (otherwise the image shifts by half a pixel).
    float x = u * float(inW) - 0.5f;
    float y = v * float(inH) - 0.5f;

    // Use floor() rather than int() truncation: for x in (-1, 0) truncation
    // gives 0 while the fraction is measured from floor(x) = -1, which would
    // blend texels 0 and 1 with the wrong weight at the left/top edge.
    float xfloor = floor(x);
    float yfloor = floor(y);
    float xfract = x - xfloor;
    float yfract = y - yfloor;

    // Clamp the neighbour indices while still signed, so an index of -1 maps
    // to the first texel rather than wrapping when converted to unsigned.
    int x0 = clamp(int(xfloor), 0, inW - 1);
    int x1 = clamp(int(xfloor) + 1, 0, inW - 1);
    int y0 = clamp(int(yfloor), 0, inH - 1);
    int y1 = clamp(int(yfloor) + 1, 0, inH - 1);

    // get the 2x2 neighbourhood
    auto p00 = GetPixelClamped(in, x0, y0, inW, inH);
    auto p10 = GetPixelClamped(in, x1, y0, inW, inH);
    auto p01 = GetPixelClamped(in, x0, y1, inW, inH);
    auto p11 = GetPixelClamped(in, x1, y1, inW, inH);

    // interpolate horizontally on both rows, then vertically between them
    float4 ret;
    for (int i = 0; i < 4; ++i)
    {
        float col0 = Lerp(p00[i], p10[i], xfract);
        float col1 = Lerp(p01[i], p11[i], xfract);
        float value = Lerp(col0, col1, yfract);
        CLAMP(value, 0.0f, 255.0f);
        ret[i] = value;
    }
    return ret;
}

// Crops a region of `in` into `out`, one output pixel per thread.
//   buffer(0)/(1) - source width/height in pixels
//   buffer(2)/(3) - output width/height in pixels; with the 1:1 pixel mapping
//                   below they cancel out of the sampling maths and are kept
//                   only so the host-side buffer bindings stay unchanged
//   buffer(4)/(5) - origin (x, y) of the crop region in normalized source
//                   coordinates (fractions of the source width/height)
// Output pixels whose source position falls outside [0, 1) are written as
// all-zero colour.
kernel void crop(texture2d<float, access::read> in [[texture(0)]],
                texture2d<float, access::write> out [[texture(1)]],
                constant int *inW [[buffer(0)]],
                constant int *inH [[buffer(1)]],
                constant int *outW [[buffer(2)]],
                constant int *outH [[buffer(3)]],
                constant float *roi_x[[buffer(4)]],
                constant float *roi_y[[buffer(5)]],
                uint2 gid [[thread_position_in_grid]])
{
    // The dispatch grid may be rounded up to whole threadgroups; threads
    // outside the output texture have nothing to write.
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }

    // Sample at the centre of destination pixel `gid`, one source pixel per
    // output pixel, offset by the ROI origin. Mapping via (gid + 0.5) instead
    // of gid / (outW - 1) avoids a division by zero for 1-pixel-wide/tall
    // outputs and keeps the crop 1:1 instead of stretching it by
    // outW / (outW - 1), which blended neighbouring source pixels.
    float2 srcSize = float2(float(*inW), float(*inH));
    float2 uv = (float2(gid) + 0.5f) / srcSize + float2(*roi_x, *roi_y);

    // 1 when uv lies inside [0, 1) on both axes, 0 otherwise.
    float grid = (step(0.0f, uv.x) - step(1.0f, uv.x)) * (step(0.0f, uv.y) - step(1.0f, uv.y));
    float4 ovlColor = SampleBilinear(in, uv.x, uv.y, *inW, *inH) * grid;
    out.write(ovlColor, gid);
}
